using LinearAlgebra, NPZ, BenchmarkTools, Random

# Object Shape Legend
# ========================================= #
#=
    spins[N, numTemps]
    betas[numTemps]
    J[N, N]
    observable[histLen, numTemps]
    polledTimes[nTimes]
    observableQualityMeasure[numTemps, nTimes]
=#

# Misc functions
# ========================================= #

# Equivalent of numpy's geomspace: `n` values evenly spaced on a log scale from `start` to
# `stop`, both endpoints included (assumes start and stop are positive).
# n == 1 returns [start], as numpy does, instead of evaluating i/(n-1) = 0/0 = NaN;
# n == 0 returns an empty vector.
function geomspace(start, stop, n)

    vals = zeros(n)
    if n == 1
        vals[1] = start
        return vals
    end
    for i = 0:n-1
        t = i/(n-1)   # fraction of the way from start to stop in log space
        vals[i + 1] = start^(1 - t) * stop^t
    end
    return vals
end

# Build `numTemps` temperatures between t0 and tf whose spacing follows a density defined on
# a log-temperature grid. `peaks` is a list of Gaussian bumps added to a flat density; each
# element is [center pos, width, amplitude]. The center is a temperature (compared in log
# space) and the width is measured in log-temperature. For example, to increase the local
# density around the transition temperature (assumed T = 1) use peaks = [[1, 0.8, 2]].
# Multiple or negative peaks are fine: if the density dips below zero, it is shifted up so
# its minimum is zero. With no peaks the result approximates a geomspace from t0 to tf.
#
# Edge cases: numTemps == 1 returns [t0] (up to exp/log rounding). If the density is so
# concentrated that the 1000-point grid cannot supply a separate point for every
# temperature, the unfilled highest slots are set to tf rather than to exp(0) = 1.
function tempMaker(t0, tf, numTemps; peaks=[])

    # Density evaluated on a fine, evenly spaced grid in log-temperature.
    numVals = 1000
    dist = ones(numVals)
    allTemps = range(log(t0), log(tf), length=numVals)

    for peak in peaks
        x0, σ, A = peak[1], peak[2], peak[3]
        dist .+= A*exp.(-(allTemps .- log(x0)).^2 / (2*σ^2))
    end

    # Make the density non-negative, then form the normalized cumulative distribution.
    minVal = minimum(dist)
    if minVal < 0
        dist .-= minVal
    end
    intDist = cumsum(dist)
    intDist /= maximum(intDist)

    # A single temperature: the quantile formula below would divide by numTemps - 1 = 0.
    if numTemps == 1
        return [exp(allTemps[1])]
    end

    # Inverse-CDF sampling: temperature k takes the first grid point whose cumulative weight
    # reaches (k-1)/(numTemps-1). At most one temperature is assigned per grid point, so any
    # slot that is never reached keeps the top of the grid (log tf).
    temps = fill(float(allTemps[end]), numTemps)
    idx = 1
    for i = 1:numVals
        idx > numTemps && break
        if intDist[i] >= (idx-1)/(numTemps-1)
            temps[idx] = allTemps[i]
            idx += 1
        end
    end

    return exp.(temps)
end

# Binder cumulant U = 1 - <x^4> / (3 <x^2>^2) of a 1D histogram `hist` whose bins are evenly
# spaced on [-1, 1] (first bin at -1, last bin at +1). Counts need not be normalized.
# The result is NaN for an all-zero histogram, or for one concentrated entirely at x = 0.
function BinderCumulant(hist)

    numVals = size(hist)[1]
    xvals = range(-1, 1, length=numVals)

    # Float accumulators from the start keep the loop type-stable (an Int 0 would switch
    # to Float64 on the first iteration).
    mom2 = 0.0
    mom4 = 0.0
    total = sum(hist)
    for i = 1:numVals
        p = hist[i]/total   # normalized weight of bin i
        mom2 += p * xvals[i]^2
        mom4 += p * xvals[i]^4
    end

    return 1 - mom4/(3*mom2^2)
end

#= Local interaction in an even confocal cavity between two atoms/ensembles at transverse
coordinates (x1, y1) and (x2, y2). ϕ is an exponential mode-cutoff parameter; sx and sy are
the Gaussian standard deviations of a single atom's/ensemble's distribution along x and y.
The result is a common envelope (in x1^2 + x2^2 and y1^2 + y2^2) multiplying the sum of a
"local" Gaussian in the separation (x1 - x2, y1 - y2) and a "mirror" Gaussian in
(x1 + x2, y1 + y2). Assumes 2*sx^2 < 1 and 2*sy^2 < 1 so the logarithms are of positive
numbers. =#
function D_local_even(x1, x2, y1, y2, ϕ, sx, sy)

    # Widths broadened by the atomic spread.
    wx2 = 1 + 2*sx^2
    wy2 = 1 + 2*sy^2

    # Additional cutoff contributed by the atomic spread along each axis.
    ϕx = log(wx2 / (1 - 2*sx^2))
    ϕy = log(wy2 / (1 - 2*sy^2))

    # Total cutoff per axis, and the hyperbolic/exponential factors reused below.
    cx, cy = ϕ + ϕx, ϕ + ϕy
    shx, shy = sinh(cx), sinh(cy)
    ex, ey = exp(ϕx), exp(ϕy)
    e_cut = exp(-ϕ)

    prefactor = 2/( π * wx2 * wy2 * sqrt( (1-exp(-2*cx)) * (1-exp(-2*cy)) ) )

    # Envelope shared by the local and mirror terms.
    gx = -(1-(ex-e_cut) / (wx2*shx)) * (x1^2+x2^2)/wx2
    gy = -(1-(ey-e_cut) / (wy2*shy)) * (y1^2+y2^2)/wy2
    envelope = gx + gy

    # Gaussian in the separation (local) and in the sum (mirror image) of the positions.
    localArg  = -ex*(x1-x2)^2/(wx2^2*shx) - ey*(y1-y2)^2/(wy2^2*shy)
    mirrorArg = -ex*(x1+x2)^2/(wx2^2*shx) - ey*(y1+y2)^2/(wy2^2*shy)

    return prefactor*( exp(envelope+localArg) + exp(envelope+mirrorArg) )
end

#= Nonlocal interaction in an even confocal cavity between atoms/ensembles at transverse
coordinates (x1, y1) and (x2, y2). ϕ is an exponential mode-cutoff parameter; sx and sy are
the Gaussian standard deviations of a single atom's/ensemble's distribution along x and y
(assumes 2*sx^2 < 1 and 2*sy^2 < 1).
Mathematically prefactor * (e^{a+ic} + e^{a-ic}) / 2 = prefactor * e^a * cos(c). The real
form is evaluated directly, so for real inputs the result is a real number, not a Complex
whose imaginary part is zero only through exact rounding cancellation. =#
function D_nonlocal_even(x1, x2, y1, y2, ϕ, sx, sy)

    wx2 = 1 + 2*sx^2
    wy2 = 1 + 2*sy^2

    ϕx = log((1+2*sx^2)/(1-2*sx^2))
    ϕy = log((1+2*sy^2)/(1-2*sy^2))

    prefactor = 4/(π * wx2 * wy2 * sqrt((1+exp(-2*(ϕ+ϕx)))*(1+exp(-2*(ϕ+ϕy)))) )

    # Gaussian envelope in x1^2 + x2^2 and y1^2 + y2^2.
    arg  = -(1-exp(-ϕ)/(wx2*cosh(ϕ+ϕx))) * (x1^2+x2^2)/wx2
    arg += -(1-exp(-ϕ)/(wy2*cosh(ϕ+ϕy))) * (y1^2+y2^2)/wy2

    # Oscillating phase, bilinear in the two positions.
    cosArg = 2*exp(ϕx)*x1*x2 / (wx2^2*cosh(ϕ+ϕx)) + 2*exp(ϕy)*y1*y2 / (wy2^2*cosh(ϕ+ϕy))

    return prefactor * (exp(arg) * cos(cosArg))
end

# J functions. The energy will always be -sum_{i <= j} J_ij f(s_i, s_j); this J_ij convention
# may differ from others by a factor of 2 and has caused us headaches in the past.
# ========================================= #

# Sherrington-Kirkpatrick couplings, normalized to variance 1/N for an extensive free energy.
# J is a symmetrized Gaussian matrix with off-diagonal variance 1/N, shifted by μ/N. The
# diagonal is zeroed unless onDiag = true. Because energies use the i <= j convention
# (-sum_{i <= j} J_ij f(s_i, s_j)), the diagonal entries are then halved.
# With useK = true, an independent symmetric matrix K (amplitude KFrac, no μ shift) is drawn
# and given the same diagonal treatment, including onDiag. The result is then the N×N×2
# array [J;;; K]. Everything is multiplied by JScale.
function SK_J(μ, N; rng=RandomDevice(), JScale=1.0, onDiag=false, useK=false, KFrac=0.5)

    J = randn(rng, N, N)
    J = (J + J') / sqrt(2*N)
    J .+= μ / N
    if !onDiag
        J[diagind(J)] .= 0
    end
    # Unfortunately due to the i <= j convention we need to halve the diagonal entries
    J[diagind(J)] ./= 2

    if useK
        K = KFrac * randn(rng, N, N)
        K = (K + K') / sqrt(2*N)
        # K must respect onDiag as well; otherwise a sine self-coupling survives even when
        # the diagonal was requested off.
        if !onDiag
            K[diagind(K)] .= 0
        end
        K[diagind(K)] ./= 2
        return [J;;; K] .* JScale
    end

    return J .* JScale
end

# Discrete SK couplings: entries are the signs of a symmetrized Gaussian matrix scaled by
# 1/sqrt(N) (so off-diagonal entries are ±1/sqrt(N), variance 1/N), shifted by μ/N. The
# diagonal is zeroed unless onDiag = true, and the result is multiplied by JScale. Unlike
# SK_J, the diagonal is not halved. useK is not supported, and KFrac is unused.
function SK_Discrete_J(μ, N; rng=RandomDevice(), JScale=1.0, onDiag=false, useK=false, KFrac=0.5)
    useK && error("K matrix not implemented for this J function")

    gauss = randn(rng, N, N)
    couplings = sign.(gauss .+ gauss') ./ sqrt(N) .+ μ / N

    if !onDiag
        couplings[diagind(couplings)] .= 0
    end

    return couplings * JScale
end

# w is the standard deviation of the positions, in units of w0=35 um.
# This interaction describes the case where we are in an even confocal
# cavity with the atoms pinned longitudinally on the l+m = 0 mod 4 modes.
# J_ij = cos(2 (x_i x_j + y_i y_j)) for Gaussian random positions (x, y). The diagonal is
# zeroed unless onDiag = true. The matrix is scaled to unit spectral radius (largest
# |eigenvalue| = 1) and then multiplied by JScale. useK is not supported; KFrac is unused.
function confocal_cos_simple_J(w, N; rng=RandomDevice(), JScale=1.0, onDiag=false, useK=false, KFrac=0.5)
    useK && error("K matrix not implemented for this J function")

    # Transverse positions: column 1 is x, column 2 is y.
    pos = w * randn(rng, N, 2)
    xs, ys = pos[:, 1], pos[:, 2]
    J = cos.(2 .* (xs * xs' .+ ys * ys'))

    if !onDiag
        J[diagind(J)] .= 0
    end

    # Unit spectral radius, then the overall coupling strength.
    J ./= maximum(abs.(eigen(J).values))
    return J .* JScale
end

# w is the standard deviation of the positions, in units of w0 = 35 um.
# J_ij = sin(2 (x_i x_j + y_i y_j)) for Gaussian random positions (x, y). The diagonal is
# zeroed unless onDiag = true. The matrix is scaled to unit spectral radius (largest
# |eigenvalue| = 1) and then multiplied by JScale. useK is not supported; KFrac is unused.
function confocal_sin_simple_J(w, N; rng=RandomDevice(), JScale=1.0, onDiag=false, useK=false, KFrac=0.5)
    useK && error("K matrix not implemented for this J function")

    # Transverse positions: column 1 is x, column 2 is y.
    pos = w * randn(rng, N, 2)
    xs, ys = pos[:, 1], pos[:, 2]
    J = sin.(2 .* (xs * xs' .+ ys * ys'))

    if !onDiag
        J[diagind(J)] .= 0
    end

    # Unit spectral radius, then the overall coupling strength.
    J ./= maximum(abs.(eigen(J).values))
    return J .* JScale
end

#= Coupling matrix for atoms in an even confocal cavity, using the full interaction
D_local_even + D_nonlocal_even with cutoff ϕ and atomic spread s along both axes.
Positions are drawn as w * randn(rng, N, 2) (w presumably in units of w0, as in the simple
variants; confirm). The diagonal is zeroed unless onDiag = true. The matrix is normalized
so its largest |eigenvalue| is 1, then multiplied by JScale. useK is not supported, and
KFrac is unused. =#
function confocal_cos_J(w, N; rng=RandomDevice(), ϕ=0, s=0.1, JScale=1.0, onDiag=false, useK=false, KFrac=0.5)
    if useK
        error("K matrix not implemented for this J function")
    end

    pos = w * randn(rng, N, 2)
    x = pos[:, 1]
    y = pos[:, 2]

    J = zeros(N, N)
    for i = 1:N
        for j = i:N
            dLocal = D_local_even(x[i], x[j], y[i], y[j], ϕ, s, s)
            dNonlocal = D_nonlocal_even(x[i], x[j], y[i], y[j], ϕ, s, s)
            # The interaction is physically real. Taking real() explicitly means a Complex
            # result with a rounding-level imaginary part cannot raise an InexactError when
            # stored into this Float64 matrix.
            J[i, j] = real(dLocal + dNonlocal)
            J[j, i] = J[i, j]
        end
    end
    if !onDiag
        J = J - diagm(diag(J))
    end
    J ./= maximum(abs.(eigen(J).values))

    return J .* JScale
end

# Observables related functions
# ========================================= #

# Magnetization per spin of a single spin vector (the mean of its entries), as a float.
function magnetization(spins)
    total = float(sum(spins))
    return total / length(spins)
end

# Overlap q = (1/N) sum_i s1_i s2_i between two single spin vectors, as a float.
overlap(spins1, spins2) = float(dot(spins1, spins2)) / length(spins1)

# Magnetization vector [m_x, m_y] = [mean(cos θ), mean(sin θ)] of one angle configuration.
function magTensor(θ)
    n = length(θ)
    return [sum(cos.(θ)) / n, sum(sin.(θ)) / n]
end

# Diagonal components [q_xx, q_yy] of the overlap tensor between two angle configurations:
# q_xx = mean(cos θ1 .* cos θ2), q_yy = mean(sin θ1 .* sin θ2).
# The off-diagonal XY and YX components are omitted (prior versions kept them).
function overlapTensor(θ1, θ2)
    n = length(θ1)
    qxx = dot(cos.(θ1), cos.(θ2)) / n
    qyy = dot(sin.(θ1), sin.(θ2)) / n
    return Float64[qxx, qyy]
end

# Overlap tensor rotated by 45 degrees for consistency with other data:
# [q_xx + q_yy, q_yy - q_xx], with q_xx = mean(cos θ1 .* cos θ2) and
# q_yy = mean(sin θ1 .* sin θ2). The rotation is not normalized by 1/sqrt(2). By the usual
# trig identities the components equal mean(cos(θ1 - θ2)) and -mean(cos(θ1 + θ2)).
function overlapTensorRotated(θ1, θ2)
    n = length(θ1)
    qxx = dot(cos.(θ1), cos.(θ2)) / n
    qyy = dot(sin.(θ1), sin.(θ2)) / n
    return Float64[qxx + qyy, qyy - qxx]
end

# Ising energy E = -sum_{i <= j} J_ij s_i s_j of a single spin vector. Each unordered pair
# is counted once and the diagonal term J_ii s_i^2 is included. This J_ij convention may
# differ from others by a factor of 2 and has caused us headaches in the past.
function IsingEnergy(spins, J)
    # Start from a zero of the product's type so E keeps one concrete type throughout the
    # loop (an Int 0 would change type on the first Float update).
    E = zero(promote_type(eltype(J), eltype(spins)))
    N = length(spins)
    for i = 1:N
        for j = i:N
            E -= J[i, j] * spins[i] * spins[j]
        end
    end
    return E
end

# Energy of one motional configuration (angles θ in `spins`):
#   E = -sum_{i <= j} [ J_ij cos(θ_i + θ_j) + K_ij sin(θ_i + θ_j) ].
# inputJ is either the N×N matrix J (no K term) or an N×N×2 array holding J and K.
function motionalEnergy(spins, inputJ)
    if ndims(inputJ) == 3
        return _motionalEnergyKernel(spins, view(inputJ, :, :, 1), view(inputJ, :, :, 2))
    end
    return _motionalEnergyKernel(spins, inputJ, nothing)
end

# Double sum over i <= j for motionalEnergy. K === nothing means there are no sine
# couplings; skipping those terms avoids allocating a zero K on every call.
function _motionalEnergyKernel(spins, J, K)
    E = 0.0   # Float from the start keeps the accumulator type-stable
    N = length(spins)
    @inbounds for i = 1:N
        for j = i:N
            s = spins[i] + spins[j]
            E -= J[i, j] * cos(s)
            if K !== nothing
                E -= K[i, j] * sin(s)
            end
        end
    end
    return E
end

# These functions used to be hardcoded into PT_poll_save, but they
# are separated out here to allow for motional versions.
# Mean |m| at temperature index i, from the magnetization histogram column mags[:, i] whose
# bins are evenly spaced on [-1, 1]. It is returned as a 1-element vector so it can fill a
# row slice of the poll-data array.
function mAbsMeans_Ising(mags, i)
    column = mags[:, i]
    weights = abs.(range(-1, 1, length=length(column))) / sum(column)
    return [dot(column, weights)]
end

# Binder cumulant of the overlap histogram at temperature index i, as a 1-element vector.
qBinders_Ising(overlaps, i) = [BinderCumulant(overlaps[:, i])]

# The following two functions find Binder cumulants of marginals of the joint histogram.
# Binder cumulants of the two 1D marginals of the joint magnetization histogram
# mags[:, :, i]: first the marginal over axis 1 (summing out axis 2), then over axis 2.
function mMarginalBinders_mo(mags, i)
    joint = mags[:, :, i]
    return [BinderCumulant(vec(sum(joint; dims=2))),
            BinderCumulant(vec(sum(joint; dims=1)))]
end

# Binder cumulants of the two 1D marginals of the joint overlap histogram overlaps[:, :, i]:
# axis 1 (summing out axis 2), then axis 2. With overlapTensor these axes are q_xx and q_yy.
# With the rotated tensor they are q_xx + q_yy and q_yy - q_xx.
function qMarginalBinders_mo(overlaps, i)
    joint = overlaps[:, :, i]
    return [BinderCumulant(vec(sum(joint; dims=2))),
            BinderCumulant(vec(sum(joint; dims=1)))]
end

# Parallel tempering related
# ========================================= #

#= Single-spin-flip Metropolis for ±1 Ising spins, in place. `spins` is N×numTemps; column t
is a replica at inverse temperature betas[t]. Each of the `steps` iterations picks one
random spin k in every replica and flips it with probability min(1, exp(-β ΔE)), where
ΔE = 2 (s_k (J s)_k - J_kk) is the energy change under E = -sum_{i <= j} J_ij s_i s_j.
This assumes J is symmetric, since column k of J is used in place of row k. =#
@views function Metropolis!(spins, betas, J, steps)
    N, numTemps = size(spins)

    @inbounds for _ in 1:steps, t in 1:numTemps
        k = rand(1:N)
        replica = spins[:, t]
        ΔE = 2 * (replica[k] * dot(J[:, k], replica) - J[k, k])
        if rand() < exp(-betas[t] * ΔE)
            replica[k] = -replica[k]
        end
    end
end
# Metropolis test. Should be zero allocations. ~16 us runtime.
#=
N = 10; numTemps=3; spins = sign.(0.5 .- rand(N, numTemps)); J=randn(N, N); J=J+J'; betas = abs.(5*rand(numTemps)); steps=100;
@btime Metropolis!(spins, betas, J, steps)
=#

#= Parallel-tempering (replica-exchange) swap step, in place; see
https://en.wikipedia.org/wiki/Parallel_tempering.
isEven = true tries the column pairs (1,2), (3,4), ...; isEven = false tries (2,3), (4,5), ...
A pair (t, t+1) is exchanged with probability
min(1, exp((E_{t+1} - E_t) * (β_{t+1} - β_t))), where E = Efunc(column, J). =#
@views function swap!(spins, betas, J, isEven; Efunc=IsingEnergy)
    N, numTemps = size(spins)

    start = isEven ? 1 : 2
    @inbounds for lo in start:2:(numTemps - 1)
        hi = lo + 1
        Ehi = Efunc(spins[:, hi], J)
        Elo = Efunc(spins[:, lo], J)

        if rand() < exp((Ehi - Elo) * (betas[hi] - betas[lo]))
            # Exchange the two replicas element by element (no temporary column).
            for k in 1:N
                spins[k, lo], spins[k, hi] = spins[k, hi], spins[k, lo]
            end
        end
    end
end
# swap test. Should be zero allocations. ~380 ns runtime.
#=
N = 10; spins = sign.(0.5 .- rand(N, 5)); J=randn(N, N); J=J+J'; betas = rand(5); isEven=true;
@btime swap!(spins, betas, J, isEven)
=#

# Histogram index of one measurement. Each component of `x` is assumed to lie in [-1, 1] and
# is mapped to one of `len` evenly spaced bins; the temperature index `i` is appended as the
# last coordinate. A scalar (magnetization, overlap) gives a 2D index. A vector (e.g.
# magTensor's [m_x, m_y]) gives one bin coordinate per component.
_ptHistIndex(x::Real, len, i) = CartesianIndex(round(Int, 1 + (x + 1)/2*(len - 1)), i)
_ptHistIndex(x, len, i) = CartesianIndex(round.(Int, 1 .+ (x .+ 1) ./ 2 .* (len - 1))..., i)

#= Parallel tempering on two independent replica sets (spins1, spins2), each N×numTemps
with column t at inverse temperature betas[t]. Every round does N Metropolis steps on both
sets, then histogram updates for both magnetizations and their overlap at every
temperature, then a swap step. Even and odd swaps alternate deterministically between
rounds.
mags and overlaps are histograms whose last axis is the temperature index. A scalar
observable fills a (numVals, numTemps) array, where each column is a histogram. A vector
observable needs one bin axis per component, e.g. (numVals, numVals, numTemps). Bin
positions use the length of the first axis. =#
@views function PT_basic!(spins1, spins2, betas, J, mags, overlaps, rounds; MetropolisFunc=Metropolis!,
    magFunc=magnetization, overlapFunc=overlap, swapFunc=swap!)

    # Get sizes of things
    N, numTemps = size(spins1)
    magLen = size(mags)[1]
    overlapLen = size(overlaps)[1]

    isEven = true
    for k = 1:rounds

        MetropolisFunc(spins1, betas, J, N)
        MetropolisFunc(spins2, betas, J, N)

        # Measure stuff and record
        for i = 1:numTemps
            mag1 = magFunc(spins1[:, i])
            mag2 = magFunc(spins2[:, i])
            q = overlapFunc(spins1[:, i], spins2[:, i])

            mags[_ptHistIndex(mag1, magLen, i)] += 1
            mags[_ptHistIndex(mag2, magLen, i)] += 1
            overlaps[_ptHistIndex(q, overlapLen, i)] += 1
        end

        swapFunc(spins1, betas, J, isEven)
        swapFunc(spins2, betas, J, isEven)

        isEven = !isEven # Deterministic even/odd swapping method
    end
end
# PT_basic test. Should be zero allocations. ~60 us runtime.
#=
N = 10; numTemps=5; spins1 = sign.(0.5 .- rand(N, numTemps)); spins2 = sign.(0.5 .- rand(N, numTemps));
J = randn(N, N); J = J + J'; betas = abs.(5*rand(numTemps));
mags = zeros(5, numTemps); overlaps = zeros(5, numTemps);
@btime PT_basic!(spins1, spins2, betas, J, mags, overlaps, 10)
=#

#= Parallel tempering, for use at scale. Runs the same loop as PT_basic! (N Metropolis steps
on both replica sets, histogram updates, deterministic even/odd swaps) and additionally:
  - at each step in allPollTimes, and at the final step, it pushes the step onto
    polledTimes. It then records mPollDataFunc(mags, t) and qPollDataFunc(overlaps, t) for
    every temperature t in the last slice of mPollData / qPollData. By default these are
    the mean |m| and the Binder cumulant of the overlap histogram.
  - every roundsPerState steps (when roundsPerState > 0) it stores snapshots of spins1 and
    spins2 in a (numStates, N, numTemps, 2) array `states`;
  - every saveInterval steps, and at the final step, it writes everything to outFile with
    npzwrite.
Step numbers continue from polledTimes[end] when polledTimes is non-empty, so a run can be
resumed. mags, overlaps, spins1, spins2 and polledTimes are updated in place. mPollData
and qPollData are written in place only until they first need a new slice. After that,
the grown copies (like `states`) reach the caller only through outFile. =#
@views function PT_poll_save!(spins1, spins2, betas, J, mags, overlaps, mPollData, qPollData,
    polledTimes, rounds, allPollTimes, saveInterval, outFile, roundsPerState; MetropolisFunc=Metropolis!,
        magFunc=magnetization, overlapFunc=overlap, swapFunc=swap!, mPollDataFunc=mAbsMeans_Ising, qPollDataFunc=qBinders_Ising)

    # Get sizes of things
    N, numTemps = size(spins1)
    magLen = size(mags)[1]
    overlapLen = size(overlaps)[1]
    prior_steps = (size(polledTimes)[1] > 0) ? polledTimes[end] : 0
    last_step = prior_steps + rounds

    # Preallocate every snapshot this call will take: the number of multiples of
    # roundsPerState in (prior_steps, last_step]. Growing the array with `cat` copied all
    # earlier snapshots each time, which is quadratic in the number of states. One
    # zero-filled row is always kept, so the saved array has the same shape as before when
    # no snapshot has been taken yet.
    numStates = roundsPerState > 0 ?
        Int(fld(last_step, roundsPerState) - fld(prior_steps, roundsPerState)) : 0
    states = zeros(Float64, max(numStates, 1), N, numTemps, 2)
    numStored = 0

    isEven = true
    for step = (prior_steps + 1):last_step

        MetropolisFunc(spins1, betas, J, N)
        MetropolisFunc(spins2, betas, J, N)

        # Measure stuff and record. Observables may be scalars or vectors; each component is
        # binned over [-1, 1] and the temperature index is the last histogram coordinate.
        for i = 1:numTemps
            mag1 = magFunc(spins1[:, i])
            mag2 = magFunc(spins2[:, i])
            overlap = overlapFunc(spins1[:, i], spins2[:, i])

            hist_idx = Tuple([Int.(round.( 1 .+ (mag1.+1)./2 .* (magLen-1) )); i])
            mags[CartesianIndex(hist_idx)] += 1

            hist_idx = Tuple([Int.(round.( 1 .+ (mag2.+1)./2 .* (magLen-1) )); i])
            mags[CartesianIndex(hist_idx)] += 1

            hist_idx = Tuple([Int.(round.( 1 .+ (overlap.+1)./2 .* (overlapLen-1) )); i])
            overlaps[CartesianIndex(hist_idx)] += 1
        end

        swapFunc(spins1, betas, J, isEven)
        swapFunc(spins2, betas, J, isEven)

        isEven = !isEven # Deterministic even/odd swapping method

        # If we are at a step where we should store a state, record it in the next free row.
        if roundsPerState > 0 && step % roundsPerState == 0
            numStored += 1
            states[numStored, :, :, 1] = spins1
            states[numStored, :, :, 2] = spins2
        end

        # If we are at a polling step, store polling data related to m and q;
        # do not store the full histograms (as this gets too big). Also do this at the end.
        if step in allPollTimes || step == last_step
            push!(polledTimes, step)
            # Create a new slice in these tables, unless it is the first time we are storing
            # data in them (at which point they are numTemps x k x 1).
            if size(polledTimes)[1] > 1
                mPollData = cat(mPollData, mPollData[:, :, end]; dims=3)
                qPollData = cat(qPollData, qPollData[:, :, end]; dims=3)
            end
            for i = 1:numTemps
                mPollData[i, :, end] = mPollDataFunc(mags, i)
                qPollData[i, :, end] = qPollDataFunc(overlaps, i)
            end
        end

        # Save our work to disk at each save interval. Not at each poll time (so information
        # may be lost on interrupt) because then file I/O becomes the bottleneck, especially
        # at early times.
        if step % saveInterval == 0 || step == last_step
            # Only the snapshots taken so far are written (at least the placeholder row).
            # copy() turns the @views slice back into a plain Array.
            k = max(numStored, 1)
            savedStates = k == size(states, 1) ? states : copy(states[1:k, :, :, :])
            npzwrite(outFile, Dict("mags" => mags, "overlaps" => overlaps,
                "mPollData" => mPollData, "qPollData" => qPollData,
                "polledTimes" => polledTimes, "spins1" => spins1, "spins2" => spins2, "states" => savedStates))
        end
    end
    # The final step already wrote to disk; nothing left to do here.
end


# Estimates the parallel-tempering swap acceptance probability between each pair of
# neighbouring temperatures, for inverse temperatures `betas` and couplings J.
# Replicas start random: ±1 spins, or uniform angles in [0, 2π) when motional = true.
# They are roughly thermalized with prepRounds Metropolis+swap rounds. Then, for `rounds`
# rounds, the acceptance probability min(1, exp(ΔE Δβ)) of every neighbouring pair is
# averaged. Returns numTemps - 1 probabilities; ideally all fall within roughly [0.3, 0.7].
@views function swapProbs(betas, J; rounds=200, prepRounds=10, motional=false)

    # Model-specific kernels.
    metro, swapper, energy = motional ?
        (Metropolis_mo!, swap_mo!, motionalEnergy) :
        (Metropolis!, swap!, IsingEnergy)

    N = size(J, 1)
    numTemps = size(betas, 1)

    # Random ±1 start; for the motional model it is replaced by uniform angles.
    isingStart = Int.(sign.(randn(N, numTemps)))
    spins = motional ? 2 .* π .* rand(N, numTemps) : isingStart

    # Rough thermalization with a few PT rounds.
    for r in 1:prepRounds
        metro(spins, betas, J, N)
        swapper(spins, betas, J, iseven(r))
    end

    # Accumulate neighbour swap probabilities while continuing PT.
    probs = zeros(numTemps - 1)
    Es = zeros(numTemps)
    for r in 1:rounds
        metro(spins, betas, J, N)

        Es[1] = energy(spins[:, 1], J)
        @inbounds for j in 1:numTemps-1
            Es[j + 1] = energy(spins[:, j+1], J)
            probs[j] += min(1, exp((Es[j] - Es[j+1]) * (betas[j] - betas[j+1])))
        end

        swapper(spins, betas, J, iseven(r))
    end

    return probs ./ rounds
end
# Example swapProbs run that shows the critical-slowing-down regime
#=
using PyPlot
N = 200
numTemps = 30
J = SK_J(0, N)
temps = geomspace(0.2, 5, numTemps)
probs = swapProbs(1 ./ temps, J, rounds=500, prepRounds=10)
fig = figure()
fill_between([temps[1], temps[numTemps]], [0.3, 0.3], [0.7, 0.7], color="tab:green", alpha=0.3)
semilogx(temps[1:numTemps-1], probs , ".-", label="a")
ylim(0, 1)
xlabel("Temps"); ylabel("Swap prob")
axvline(1, label=L"T_c", linestyle="--", alpha=0.5)
legend()
PyPlot.display_figs()
=#

# Motional model
# ========================================= #

#= Metropolis updates for the motional model. Spins are really a list of angles, wrapped
into [0, 2π) at the end. Column t of `spins` is a replica at inverse temperature betas[t].
inputJ is either an N×N coupling matrix J or an N×N×2 array holding J and K. The energy
changes below correspond to E = -sum_{i <= j} [J_ij cos(θ_i+θ_j) + K_ij sin(θ_i+θ_j)].
Each of the `steps` iterations proposes, for every replica, a Gaussian move of one random
angle (standard deviation stepSize). The move is accepted with probability
min(1, exp(-β ΔE)). Modifies spins in place and also returns it. =#
@views function Metropolis_mo!(spins, betas, inputJ, steps; stepSize=π/8)
    if ndims(inputJ) == 3
        _metropolisMoKernel!(spins, betas, inputJ[:, :, 1], inputJ[:, :, 2], steps, stepSize)
    else
        # No K matrix: its terms are skipped rather than multiplied by an allocated zero K.
        _metropolisMoKernel!(spins, betas, inputJ, nothing, steps, stepSize)
    end
    # Wrap in place. Rebinding `spins = mod2pi.(spins)` left the caller's angles unwrapped.
    spins .= mod2pi.(spins)
    return spins
end

# Inner Metropolis loop of Metropolis_mo!. K === nothing means there are no sine couplings.
function _metropolisMoKernel!(spins, betas, J, K, steps, stepSize)
    N, numTemps = size(spins)
    @inbounds for step = 1:steps
        for i = 1:numTemps
            idx = rand(1:N)
            dθ = randn() * stepSize
            θidx = spins[idx, i]
            dE = 0.0   # Float from the start keeps the accumulator type-stable
            for k = 1:N
                θk = spins[k, i]
                if k == idx
                    # Self-term: the pair (idx, idx) contributes cos/sin(2θ_idx).
                    dE -= J[idx, idx] * ( cos(2*θidx + 2*dθ) - cos(2*θidx) )
                    if K !== nothing
                        dE -= K[idx, idx] * ( sin(2*θidx + 2*dθ) - sin(2*θidx) )
                    end
                else
                    dE -= J[k, idx] * ( cos(θidx + θk + dθ) - cos(θidx + θk) )
                    if K !== nothing
                        dE -= K[k, idx] * ( sin(θidx + θk + dθ) - sin(θidx + θk) )
                    end
                end
            end

            # See if the random number is less than the Boltzmann weight
            if rand() < exp(-betas[i] * dE)
                spins[idx, i] += dθ
            end
        end
    end
    return nothing
end
# Test of Metropolis_mo
#=
using PyPlot
N = 200
J = SK_J(0, N)
temp = [0.1]
polls = 500
Es = zeros(polls)
spins = rand(N, 1)*2π
for i = 1:polls
    Es[i] = motionalEnergy(spins[:, 1], J)
    Metropolis_mo!(spins, 1 ./temp, J, 100; stepSize=π/8)
end
fig=figure()
plot(Es)
PyPlot.display_figs()
=#

# Parallel-tempering swap step for the motional model: swap! with motionalEnergy as the
# energy function.
swap_mo!(spins, betas, J, isEven) = swap!(spins, betas, J, isEven; Efunc=motionalEnergy)

# Motional-model variant of PT_basic!. Same loop, with the motional Metropolis kernel,
# tensor observables (magTensor / overlapTensor) and motional swaps as defaults.
# NOTE(review): the observables are vectors, so this requires PT_basic! to bin
# vector-valued measurements (one histogram axis per component plus the temperature axis).
function PT_basic_mo!(spins1, spins2, betas, J, mags, overlaps, rounds;
        MetropolisFunc=Metropolis_mo!, magFunc=magTensor, overlapFunc=overlapTensor,
        swapFunc=swap_mo!)
    return PT_basic!(spins1, spins2, betas, J, mags, overlaps, rounds;
        MetropolisFunc, magFunc, overlapFunc, swapFunc)
end

# Motional-model variant of PT_poll_save!, with the same positional arguments. The defaults
# are the motional Metropolis kernel and swaps, magTensor, the 45°-rotated overlap tensor,
# and Binder cumulants of the histogram marginals as poll data.
function PT_poll_save_mo!(spins1, spins2, betas, J, mags, overlaps, mPollData, qPollData,
        polledTimes, rounds, allPollTimes, saveInterval, outFile, roundsPerState;
        MetropolisFunc=Metropolis_mo!, magFunc=magTensor, overlapFunc=overlapTensorRotated,
        swapFunc=swap_mo!, mPollDataFunc=mMarginalBinders_mo, qPollDataFunc=qMarginalBinders_mo)
    return PT_poll_save!(spins1, spins2, betas, J, mags, overlaps, mPollData, qPollData,
        polledTimes, rounds, allPollTimes, saveInterval, outFile, roundsPerState;
        MetropolisFunc, magFunc, overlapFunc, swapFunc, mPollDataFunc, qPollDataFunc)
end

# Swap-probability estimate for the motional model (swapProbs with motional = true).
swapProbs_mo(betas, J; rounds=200, prepRounds=10) =
    swapProbs(betas, J; rounds, prepRounds, motional=true)
# # Example for swapProbs_mo
#=
using PyPlot
N = 100
# function SK_J(μ, N; rng=RandomDevice(), JScale=1.0, onDiag=false, useK=false, KFrac=0.5)
J = SK_J(0.0, N, JScale=2.0, useK=true, onDiag=true)
# temps = geomspace(0.01, 2., 50)
temps = [0.1, 0.115, 0.133, 0.154, 0.178, 0.205, 0.237, 0.274, 0.316, 0.365, 0.422, 0.487, 0.562, 0.649, 0.75, 0.866, 1.0, 1.15, 1.4, 2.0, 3.0, 5.0]
numTemps = size(temps)[1]
probs = swapProbs_mo(1 ./ temps, J, rounds=500, prepRounds=100)
fig = figure()
fill_between([temps[1], temps[numTemps]], [0.3, 0.3], [0.7, 0.7], color="tab:green", alpha=0.3)
xs = sqrt.(temps[1:numTemps-1] .* temps[2:numTemps])
semilogx(xs, probs, ".-", label="p")
ylim(0, 1)
xlabel("T (GM of successive pairs)"); ylabel("Swap probability")
axvline(1.0, label=L"T_c(?)", linestyle="--", alpha=0.5)
legend()
PyPlot.display_figs()
=#